// Copyright (c) Six Labors.
// Licensed under the Six Labors Split License.

using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

// ReSharper disable InconsistentNaming
namespace SixLabors.ImageSharp.Tests.Formats.Png;

/// <summary>
/// This class contains reference implementations to produce verification data for unit tests
/// </summary>
internal static partial class ReferenceImplementations
{
    /// <summary>
    /// Applies the PNG Paeth filter (filter type 4) to <paramref name="scanline"/>.
    /// Writes the filter-type byte to <c>result[0]</c> and the filtered bytes to <c>result[1..scanline.Length]</c>.
    /// </summary>
    /// <param name="scanline">The raw bytes of the scanline to encode.</param>
    /// <param name="previousScanline">The raw bytes of the previous scanline; must be the same length as <paramref name="scanline"/>.</param>
    /// <param name="result">Receives the filter-type byte followed by the filtered bytes; must be longer than <paramref name="scanline"/>.</param>
    /// <param name="bytesPerPixel">The distance, in bytes, to the corresponding byte of the pixel on the left; must be positive.</param>
    /// <param name="sum">The sum of the absolute values of the filtered bytes, each interpreted as a signed byte.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void EncodePaethFilter(ReadOnlySpan<byte> scanline, ReadOnlySpan<byte> previousScanline, Span<byte> result, int bytesPerPixel, out int sum)
    {
        DebugGuard.MustBeSameSized(scanline, previousScanline, nameof(scanline));

        // The filtered bytes start at result[1], so result needs one more byte than scanline.
        DebugGuard.MustBeGreaterThan(result.Length, scanline.Length, nameof(result));
        DebugGuard.MustBeGreaterThan(bytesPerPixel, 0, nameof(bytesPerPixel));

        ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline);
        ref byte prevBaseRef = ref MemoryMarshal.GetReference(previousScanline);
        ref byte resultBaseRef = ref MemoryMarshal.GetReference(result);
        sum = 0;

        // Paeth(x) = Raw(x) - PaethPredictor(Raw(x-bpp), Prior(x), Prior(x - bpp))
        resultBaseRef = 4;

        // For the first pixel there is no left neighbour, so Raw(x-bpp) and Prior(x-bpp) are 0.
        // Clamping to the scanline length keeps an oversized bytesPerPixel from running past the spans.
        int firstPixelEnd = Math.Min(bytesPerPixel, scanline.Length);
        int x = 0;
        for (; x < firstPixelEnd; x++)
        {
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte above = Unsafe.Add(ref prevBaseRef, x);
            byte filtered = (byte)(scan - PaethPredictor(0, above, 0));
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }

        for (; x < scanline.Length; x++)
        {
            int xLeft = x - bytesPerPixel;
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte left = Unsafe.Add(ref scanBaseRef, xLeft);
            byte above = Unsafe.Add(ref prevBaseRef, x);
            byte upperLeft = Unsafe.Add(ref prevBaseRef, xLeft);
            byte filtered = (byte)(scan - PaethPredictor(left, above, upperLeft));
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }
    }

    /// <summary>
    /// Applies the PNG Sub filter (filter type 1) to <paramref name="scanline"/>.
    /// Writes the filter-type byte to <c>result[0]</c> and the filtered bytes to <c>result[1..scanline.Length]</c>.
    /// </summary>
    /// <param name="scanline">The raw bytes of the scanline to encode.</param>
    /// <param name="result">Receives the filter-type byte followed by the filtered bytes; must be longer than <paramref name="scanline"/>.</param>
    /// <param name="bytesPerPixel">The distance, in bytes, to the corresponding byte of the pixel on the left; must be positive.</param>
    /// <param name="sum">The sum of the absolute values of the filtered bytes, each interpreted as a signed byte.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void EncodeSubFilter(ReadOnlySpan<byte> scanline, Span<byte> result, int bytesPerPixel, out int sum)
    {
        // The filtered bytes start at result[1], so result needs one more byte than scanline.
        DebugGuard.MustBeGreaterThan(result.Length, scanline.Length, nameof(result));
        DebugGuard.MustBeGreaterThan(bytesPerPixel, 0, nameof(bytesPerPixel));

        ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline);
        ref byte resultBaseRef = ref MemoryMarshal.GetReference(result);
        sum = 0;

        // Sub(x) = Raw(x) - Raw(x-bpp)
        resultBaseRef = 1;

        // For the first pixel Raw(x-bpp) is 0, so the byte is copied unchanged.
        // Clamping to the scanline length keeps an oversized bytesPerPixel from running past the spans.
        int firstPixelEnd = Math.Min(bytesPerPixel, scanline.Length);
        int x = 0;
        for (; x < firstPixelEnd; x++)
        {
            byte filtered = Unsafe.Add(ref scanBaseRef, x);
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }

        for (; x < scanline.Length; x++)
        {
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel);
            byte filtered = (byte)(scan - left);
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }
    }

    /// <summary>
    /// Applies the PNG Up filter (filter type 2) to <paramref name="scanline"/>.
    /// Writes the filter-type byte to <c>result[0]</c> and the filtered bytes to <c>result[1..scanline.Length]</c>.
    /// </summary>
    /// <param name="scanline">The raw bytes of the scanline to encode.</param>
    /// <param name="previousScanline">The raw bytes of the previous scanline; must be the same length as <paramref name="scanline"/>.</param>
    /// <param name="result">Receives the filter-type byte followed by the filtered bytes; must be longer than <paramref name="scanline"/>.</param>
    /// <param name="sum">The sum of the absolute values of the filtered bytes, each interpreted as a signed byte.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void EncodeUpFilter(ReadOnlySpan<byte> scanline, ReadOnlySpan<byte> previousScanline, Span<byte> result, out int sum)
    {
        DebugGuard.MustBeSameSized(scanline, previousScanline, nameof(scanline));

        // The filtered bytes start at result[1], so result needs one more byte than scanline.
        DebugGuard.MustBeGreaterThan(result.Length, scanline.Length, nameof(result));

        ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline);
        ref byte prevBaseRef = ref MemoryMarshal.GetReference(previousScanline);
        ref byte resultBaseRef = ref MemoryMarshal.GetReference(result);
        sum = 0;

        // Up(x) = Raw(x) - Prior(x)
        resultBaseRef = 2;

        for (int x = 0; x < scanline.Length; x++)
        {
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte above = Unsafe.Add(ref prevBaseRef, x);
            byte filtered = (byte)(scan - above);
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }
    }

    /// <summary>
    /// Applies the PNG Average filter (filter type 3) to <paramref name="scanline"/>.
    /// Writes the filter-type byte to <c>result[0]</c> and the filtered bytes to <c>result[1..scanline.Length]</c>.
    /// </summary>
    /// <param name="scanline">The raw bytes of the scanline to encode.</param>
    /// <param name="previousScanline">The raw bytes of the previous scanline; must be the same length as <paramref name="scanline"/>.</param>
    /// <param name="result">Receives the filter-type byte followed by the filtered bytes; must be longer than <paramref name="scanline"/>.</param>
    /// <param name="bytesPerPixel">The distance, in bytes, to the corresponding byte of the pixel on the left; must be positive.</param>
    /// <param name="sum">The sum of the absolute values of the filtered bytes, each interpreted as a signed byte.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void EncodeAverageFilter(ReadOnlySpan<byte> scanline, ReadOnlySpan<byte> previousScanline, Span<byte> result, int bytesPerPixel, out int sum)
    {
        DebugGuard.MustBeSameSized(scanline, previousScanline, nameof(scanline));

        // The filtered bytes start at result[1], so result needs one more byte than scanline.
        DebugGuard.MustBeGreaterThan(result.Length, scanline.Length, nameof(result));
        DebugGuard.MustBeGreaterThan(bytesPerPixel, 0, nameof(bytesPerPixel));

        ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline);
        ref byte prevBaseRef = ref MemoryMarshal.GetReference(previousScanline);
        ref byte resultBaseRef = ref MemoryMarshal.GetReference(result);
        sum = 0;

        // Average(x) = Raw(x) - floor((Raw(x-bpp)+Prior(x))/2)
        resultBaseRef = 3;

        // For the first pixel Raw(x-bpp) is 0, so the predictor reduces to Prior(x) / 2.
        // Clamping to the scanline length keeps an oversized bytesPerPixel from running past the spans.
        int firstPixelEnd = Math.Min(bytesPerPixel, scanline.Length);
        int x = 0;
        for (; x < firstPixelEnd; x++)
        {
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte above = Unsafe.Add(ref prevBaseRef, x);
            byte filtered = (byte)(scan - (above >> 1));
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }

        for (; x < scanline.Length; x++)
        {
            byte scan = Unsafe.Add(ref scanBaseRef, x);
            byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel);
            byte above = Unsafe.Add(ref prevBaseRef, x);
            byte filtered = (byte)(scan - Average(left, above));
            Unsafe.Add(ref resultBaseRef, x + 1) = filtered;
            sum += Numerics.Abs(unchecked((sbyte)filtered));
        }
    }

    /// <summary>
    /// Calculates the floor of the average of two bytes, as used by the PNG Average filter.
    /// </summary>
    /// <param name="left">The left byte.</param>
    /// <param name="above">The above byte.</param>
    /// <returns>The value <c>(left + above) / 2</c>, rounded down.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static int Average(byte left, byte above) => (left + above) >> 1;

    /// <summary>
    /// Computes a simple linear function of the three neighboring pixels (left, above, upper left), then chooses
    /// as predictor the neighboring pixel closest to the computed value.
    /// </summary>
    /// <param name="left">The left neighbor pixel.</param>
    /// <param name="above">The above neighbor pixel.</param>
    /// <param name="upperLeft">The upper left neighbor pixel.</param>
    /// <returns>
    /// Whichever of <paramref name="left"/>, <paramref name="above"/> or <paramref name="upperLeft"/> is closest to
    /// <c>left + above - upperLeft</c>, breaking ties in that order.
    /// </returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static byte PaethPredictor(byte left, byte above, byte upperLeft)
    {
        int p = left + above - upperLeft;
        int pa = Numerics.Abs(p - left);
        int pb = Numerics.Abs(p - above);
        int pc = Numerics.Abs(p - upperLeft);

        // The tie-break order (left, then above, then upper-left) is mandated by the PNG specification.
        if (pa <= pb && pa <= pc)
        {
            return left;
        }

        if (pb <= pc)
        {
            return above;
        }

        return upperLeft;
    }
}
